import argparse
import json
import os
import random
import time
import uuid

# from os import makedirs
from os.path import join

import numpy as np
import torch
import torchvision.models
import yaml
from torch import nn
from torch.optim import SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, MultiStepLR
from torch.utils.data import Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100

from datasets.clothing1mpp import Clothing1mPP
from datasets.coresCIFAR import coresCIFAR10, coresCIFAR100
from loggers.base import Logger
from loggers.wandb import wandbLogger


def _get_clothing1mpp_datasets(config):
    """Build the Clothing1mPP train/test/val datasets from ``config["data"]``.

    Returns:
        Tuple ``(train_dataset, test_dataset, val_dataset, num_classes)``.
    """
    data_cfg = config["data"]
    train_dataset_full = Clothing1mPP(
        data_cfg["root"], data_cfg["image_size"], split="train"
    )
    if data_cfg["imbalance_factor"] < 1.0:
        print("The imbalance factor is ", data_cfg["imbalance_factor"])
        train_set_clean = Clothing1mPP(
            data_cfg["root"],
            data_cfg["image_size"],
            split="train",
            clean_label=True,
        )
        imbalance_ids = train_set_clean.get_imbalance_ids(
            0,  # TODO: fix seed
            imbalance_factor=data_cfg["imbalance_factor"],
        )
        train_dataset = Subset(train_set_clean, imbalance_ids)
    elif data_cfg["tiny"]:
        tiny_set_ids = train_dataset_full.get_tiny_ids(seed=0)
        train_dataset = Subset(train_dataset_full, tiny_set_ids)
    else:
        train_dataset = train_dataset_full

    # The test and val splits reuse the data package already loaded by the
    # full training set (passed through ``pre_load``).
    test_dataset = Clothing1mPP(
        data_cfg["root"],
        data_cfg["image_size"],
        split="test",
        pre_load=train_dataset_full.data_package,
    )
    val_dataset = Clothing1mPP(
        data_cfg["root"],
        data_cfg["image_size"],
        split="val",
        pre_load=train_dataset_full.data_package,
    )
    return train_dataset, test_dataset, val_dataset, train_dataset_full.num_classes


def _get_cifar_datasets(config, name):
    """Build the CIFAR-family train/test datasets for the lower-cased ``name``.

    Returns:
        Tuple ``(train_dataset, test_dataset, num_classes)``.

    Raises:
        ValueError: If ``name`` matches none of the supported CIFAR variants.
    """
    data_cfg = config["data"]
    mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )

    # The cores variants must be matched before the plain ones, and the
    # 100-class name before the 10-class name: "corescifar100" contains both
    # "cifar100" and "corescifar10" as substrings.
    cores_variants = (
        ("corescifar100", coresCIFAR100, 100, True),
        ("corescifar10", coresCIFAR10, 10, False),
    )
    for key, cores_cls, num_classes, download in cores_variants:
        if key in name:
            train_dataset = cores_cls(
                root="./data/",
                download=download,
                train=True,
                transform=train_transform,
                noise_type=data_cfg["noise_type"],
                noise_rate=data_cfg["noise_rate"],
            )
            test_dataset = cores_cls(
                root="./data/",
                download=download,
                train=False,
                transform=test_transform,
                noise_type=data_cfg["noise_type"],
                noise_rate=data_cfg["noise_rate"],
            )
            return train_dataset, test_dataset, num_classes

    if "cifar100" in name:
        num_classes = 100
        train_dataset = CIFAR100(
            data_cfg["root"], train=True, download=True, transform=train_transform
        )
        test_dataset = CIFAR100(
            data_cfg["root"], train=False, download=True, transform=test_transform
        )
        if "cifar100-n" in name:
            # Only read the label file when the noisy variant is requested.
            cifar100_n_label = torch.load(
                f"{os.getcwd()}/datasets/bin/CIFAR-100_human.pt"
            )
            train_dataset.targets = cifar100_n_label["noisy_label"]
    elif "cifar10" in name and "cores" not in name:
        num_classes = 10
        train_dataset = CIFAR10(
            data_cfg["root"], train=True, download=True, transform=train_transform
        )
        test_dataset = CIFAR10(
            data_cfg["root"], train=False, download=True, transform=test_transform
        )
        # Keep the targets shipped with the dataset before any replacement.
        train_dataset.ori_targets = train_dataset.targets
        label_keys = (
            ("cifar10-n-aggre", "aggre_label"),
            ("cifar10-n-worst", "worse_label"),
            ("cifar10-n-random", "random_label1"),
        )
        for variant, label_key in label_keys:
            if variant in name:
                # Only read the label file when a noisy variant is requested.
                cifar10_n_label = torch.load(
                    f"{os.getcwd()}/datasets/bin/CIFAR-10_human.pt"
                )
                train_dataset.targets = cifar10_n_label[label_key]
                break
    else:
        raise ValueError(f"Unknown CIFAR dataset: {config['data']['dataset_name']}")

    train_dataset.num_classes = num_classes
    test_dataset.num_classes = num_classes
    return train_dataset, test_dataset, num_classes


def get_datasets(config):
    """Create the datasets selected by ``config["data"]["dataset_name"]``.

    ``"clothing1mpp"`` (exact match) builds the Clothing1mPP splits; any name
    containing ``"cifar"`` (case-insensitive) builds one of the CIFAR10 /
    CIFAR100 / coresCIFAR10 / coresCIFAR100 variants, for which no validation
    set is created.

    Args:
        config: Nested configuration mapping; the ``"data"`` section is read.

    Returns:
        Tuple ``(train_dataset, test_dataset, val_dataset, num_classes)`` where
        ``val_dataset`` is ``None`` for the CIFAR family.

    Raises:
        ValueError: If the dataset name is not recognised.
    """
    dataset_name = config["data"]["dataset_name"]
    if dataset_name == "clothing1mpp":
        (
            train_dataset,
            test_dataset,
            val_dataset,
            num_classes,
        ) = _get_clothing1mpp_datasets(config)
    elif "cifar" in dataset_name.lower():
        train_dataset, test_dataset, num_classes = _get_cifar_datasets(
            config, dataset_name.lower()
        )
        val_dataset = None
    else:
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError(f"Unknown dataset: {dataset_name}")

    print(
        f"Train size: {len(train_dataset)}, Test size: {len(test_dataset)}, Val size: {len(val_dataset) if val_dataset else 0}"
    )
    return train_dataset, test_dataset, val_dataset, num_classes


def get_logger(config):
    """Instantiate the logger requested in ``config["general"]["logger"]``.

    A logger named ``"wandb"`` yields a ``wandbLogger`` (a non-empty
    ``wandb_key`` is asserted); any other name yields the plain ``Logger``.
    Both receive ``join(save_root, run_name)`` as their output location.
    """
    general_cfg = config["general"]
    logger_cfg = general_cfg["logger"]

    if logger_cfg["name"] != "wandb":
        return Logger(join(general_cfg["save_root"], general_cfg["run_name"]))

    assert logger_cfg["wandb_key"] != "", "The wandb key cannot be empty"
    return wandbLogger(
        logger_cfg["wandb_key"],
        logger_cfg["project_name"],
        logger_cfg["wandb_entity"],
        join(general_cfg["save_root"], general_cfg["run_name"]),
    )


def get_model(config, num_classes):
    """Build a torchvision model and resize its classification head.

    Args:
        config: Nested configuration mapping. ``config["model"]["name"]`` is
            lower-cased and looked up in ``torchvision.models``; the optional
            ``config["model"]["pretrained_model"]`` is forwarded as
            ``weights``; ``config["model"]["cifar"]`` enables the CIFAR stem.
        num_classes: Output size of the replaced classification layer.

    Returns:
        The constructed model.

    Raises:
        AttributeError: If ``torchvision.models`` has no builder of that name.
        NotImplementedError: If the CIFAR stem is requested for a model other
            than resnet18.
    """
    model_cfg = config["model"]
    model_name = model_cfg["name"].lower()
    print(f"==> Building model {model_name}")

    try:
        pretrained_model = model_cfg["pretrained_model"]
    except KeyError:
        pretrained_model = None
    else:
        print(f"==> Loading weights {pretrained_model}")

    # Resolve the builder separately from calling it: a dict lookup would
    # raise KeyError (not AttributeError) for unknown names, and errors raised
    # while constructing a valid model must not be reported as "not supported".
    builder = getattr(torchvision.models, model_name, None)
    if not callable(builder):
        raise AttributeError("This model is not supported")
    model = builder(weights=pretrained_model)

    if "resnet" in model_name:
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif "vgg" in model_name:
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    elif "inception" in model_name:
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif "vit" in model_name:
        model.heads[-1] = nn.Linear(model.heads[-1].in_features, num_classes)
        model.num_classes = num_classes
    else:
        # No head replacement is defined for this family; make that visible
        # rather than silently returning a model with its original head.
        print(
            f"Warning: classification head of {model_name} was not resized "
            f"to {num_classes} classes"
        )

    if model_cfg["cifar"]:
        if "resnet18" in model_name:
            # Replace the stem with a 3x3 convolution and drop the max-pool.
            model.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
            model.bn1 = nn.BatchNorm2d(64)
            model.relu = nn.ReLU(inplace=True)
            model.maxpool = nn.Identity()
        else:
            raise NotImplementedError("The other models have not been supported, yet")
    return model


def get_lr_scheduler(optimizer, params):
    """Create the learning-rate scheduler named by ``params.lr_scheduler``.

    Returns ``None`` when ``params`` has no ``lr_scheduler`` entry. Supported
    names are ``constant``, ``cosine``, ``step`` and ``linear``; anything else
    raises ``ValueError``.
    """
    if "lr_scheduler" not in params:
        return None

    kind = params.lr_scheduler
    n_epochs = params.epochs

    # Factories are lazy so that scheduler-specific fields (decay_steps,
    # gamma) are only read when that scheduler is actually selected.
    factories = {
        "constant": lambda: LambdaLR(optimizer, lambda _: 1.0),
        "cosine": lambda: CosineAnnealingLR(optimizer, T_max=n_epochs),
        "step": lambda: MultiStepLR(
            optimizer, params.decay_steps, gamma=params.gamma
        ),
        "linear": lambda: LambdaLR(optimizer, lambda epoch: 1 - epoch / n_epochs),
    }
    make = factories.get(kind)
    if make is None:
        raise ValueError(f"Unknown scheduler: {kind}")
    return make()


def get_optimizer(model, params):
    """Create the optimizer named by ``params.optimizer`` for ``model``.

    Returns ``None`` when ``params`` has no ``optimizer`` entry. ``sgd`` builds
    a non-Nesterov SGD with momentum; ``adamw`` builds AdamW; any other name
    raises ``ValueError``. ``lr``, ``weight_decay`` and ``momentum`` are read
    from ``params`` regardless of which optimizer is chosen.
    """
    if "optimizer" not in params:
        return None

    kind = params.optimizer
    shared = {"lr": params.lr, "weight_decay": params.weight_decay}
    sgd_momentum = params.momentum

    if kind == "adamw":
        return AdamW(model.parameters(), **shared)
    if kind == "sgd":
        return SGD(
            model.parameters(),
            momentum=sgd_momentum,
            nesterov=False,
            **shared,
        )
    raise ValueError(f"Unknown optimizer: {kind}")


def seed_everything(seed):
    """Seed Python, NumPy and torch RNGs and switch cuDNN to deterministic mode."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN: force deterministic kernels and disable the auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def print_dict(d, col1_name="Parameter", col2_name="Value"):
    """Print a dictionary as a two-column ASCII table.

    Args:
        d: Mapping to print; keys and values are rendered with ``str``.
        col1_name: Header of the key column.
        col2_name: Header of the value column.

    An empty mapping prints just the header between borders instead of
    failing on ``max()`` of an empty sequence.
    """
    # Column widths: widest cell, but never narrower than the header.
    key_width = max(len(col1_name), max((len(str(k)) for k in d), default=0))
    val_width = max(
        len(col2_name), max((len(str(v)) for v in d.values()), default=0)
    )

    border = f"+{'-' * (key_width + 2)}+{'-' * (val_width + 2)}+"

    print(border)
    print(f"| {col1_name.ljust(key_width)} | {col2_name.ljust(val_width)} |")
    print(border)
    for k, v in d.items():
        print(f"| {str(k).ljust(key_width)} | {str(v).ljust(val_width)} |")
    print(border)


def print_title(title, banner_width=40):
    """Print a blank line, then ``title`` centred between two ``=`` rules.

    Underscores in ``title`` become spaces and the result is title-cased.
    """
    print()
    rule = "=" * banner_width
    print(rule)
    words = title.replace("_", " ")
    print(words.title().center(banner_width))
    print(rule)


def merge_dicts(list_of_dicts):
    """Combine several dictionaries into a new one; later entries win on key clashes."""
    return {
        key: value
        for mapping in list_of_dicts
        for key, value in mapping.items()
    }


def save_as_yaml(data, path):
    """Save a dict-like object as YAML at ``path``.

    The object is converted with ``to_regular_dict`` *before* the file is
    opened, so a conversion failure does not truncate an existing file or
    leave an empty one behind.
    """
    plain = to_regular_dict(data)
    with open(path, "w") as f:
        yaml.dump(plain, f)


def to_regular_dict(obj):
    """Return a plain-Python copy of ``obj`` obtained by a JSON round trip.

    Dict subclasses (e.g. EasyDict) come back as ordinary ``dict`` objects.
    """
    serialized = json.dumps(obj)
    return json.loads(serialized)


def avg_gap(unlearn_results, retrain_results):
    """Return the mean absolute difference between two result dicts.

    The average is taken over ``remain_acc``, ``unlearn_acc``, ``test_acc``
    and ``mia_confidence``. If either dict lacks one of these keys a message
    is printed and ``None`` is returned.
    """
    metrics = ["remain_acc", "unlearn_acc", "test_acc", "mia_confidence"]

    has_all_unlearn = all(name in unlearn_results for name in metrics)
    has_all_retrain = all(name in retrain_results for name in metrics)
    if not (has_all_unlearn and has_all_retrain):
        print("Avg gap calculation failed because of missing metrics")
        return None

    gaps = [abs(unlearn_results[name] - retrain_results[name]) for name in metrics]
    return sum(gaps) / len(gaps)


def merge_configs(base_config, override_config):
    """Recursively merge ``override_config`` into ``base_config`` in place.

    When both sides hold a dict under the same key the two are merged
    recursively; otherwise the override value replaces (or adds) the entry.
    """
    for name, incoming in override_config.items():
        both_are_dicts = (
            isinstance(incoming, dict)
            and name in base_config
            and isinstance(base_config[name], dict)
        )
        if not both_are_dicts:
            # Plain value, new key, or type mismatch: the override wins outright.
            base_config[name] = incoming
            continue
        merge_configs(base_config[name], incoming)


def load_config(config_file):
    """Load a YAML config, resolving ``inherit_from`` chains.

    If the file has a truthy top-level ``inherit_from`` entry, that file is
    loaded first (recursively) and the current file's values are merged on
    top of it with ``merge_configs``.

    Args:
        config_file: Path to the YAML file.

    Returns:
        The merged configuration dictionary (``{}`` for an empty file).
    """
    with open(config_file, "r") as file:
        # safe_load returns None for an empty document; treat it as "no keys".
        config = yaml.safe_load(file) or {}

    inherit_from = config.get("inherit_from")
    if inherit_from:
        # Paths are resolved as given first; only if that does not exist, fall
        # back to a path relative to the directory of the inheriting file.
        if not os.path.isabs(inherit_from) and not os.path.exists(inherit_from):
            sibling = os.path.join(
                os.path.dirname(os.path.abspath(config_file)), inherit_from
            )
            if os.path.exists(sibling):
                inherit_from = sibling

        # Load values from the inherited configuration file
        inherited_config = load_config(inherit_from)
        merge_configs(inherited_config, config)
        config = inherited_config

    return config


def print_config(config_dict, indent=0):
    """Print a nested dict as ``key: value`` lines, indenting each level by 4 spaces."""
    pad = " " * indent
    for name, entry in config_dict.items():
        if not isinstance(entry, dict):
            print(pad + f"{name}: {entry}")
            continue
        # Nested section: print its header, then recurse one level deeper.
        print(pad + f"{name}:")
        print_config(entry, indent + 4)


def update_save_dir(config):
    """Rewrite ``save_root`` into a run-specific directory and set ``save_model_dir``.

    The new ``save_root`` nests, in order: the original root,
    ``{dataset}_{imbalance_factor}_{noise_rate}``, the model name, the batch
    size, the learning rate, the run name and the torch / numpy seeds.
    ``save_model_dir`` becomes the ``model`` sub-directory of that path.
    ``config`` is modified in place.
    """
    general_cfg = config["general"]
    data_cfg = config["data"]
    train_cfg = config["train"]

    data_tag = "{}_{}_{}".format(
        data_cfg["dataset_name"],
        data_cfg["imbalance_factor"],
        data_cfg["noise_rate"],
    )
    path_parts = [
        general_cfg["save_root"],
        data_tag,
        config["model"]["name"],
        str(train_cfg["batch_size"]),
        str(train_cfg["learning_rate"]),
        general_cfg["run_name"],
        "torch_seed_{}".format(general_cfg["torch_seed"]),
        "np_seed_{}".format(general_cfg["np_seed"]),
    ]

    run_dir = join(*path_parts)
    general_cfg["save_root"] = run_dir
    general_cfg["save_model_dir"] = join(run_dir, "model")


def remove_all_sub_files(path):
    """
    Delete every file and symlink below ``path`` while keeping the directory tree.

    Sub-directories are emptied recursively but not removed. A failure on one
    entry is reported and the remaining entries are still processed.
    """
    if not os.path.isdir(path):
        print("The provided path is not a directory.")
        return

    for entry_name in os.listdir(path):
        entry_path = os.path.join(path, entry_name)
        try:
            is_leaf = os.path.isfile(entry_path) or os.path.islink(entry_path)
            if is_leaf:
                # Regular file or symlink: remove the entry itself.
                os.unlink(entry_path)
                continue
            if os.path.isdir(entry_path):
                # Real directory: empty it, but leave the directory in place.
                remove_all_sub_files(entry_path)
        except Exception as e:
            print(f"Failed to delete {entry_path}. Reason: {e}")


def create_save_dir(config):
    """Create the run's save directory and its model sub-directory.

    Reads ``save_root``, ``save_model_dir`` and ``whip_existing_files`` from
    ``config["general"]``:

    * If ``save_root`` does not exist it is created.
    * If it exists, is non-empty and ``whip_existing_files`` is false, an
      explanation is printed and the process exits with status 1 so that
      calling scripts can detect the refusal.
    * If it exists and ``whip_existing_files`` is true, all files below it
      are removed with ``remove_all_sub_files``.

    Finally ``save_model_dir`` is created if missing.

    Raises:
        SystemExit: With code 1 when the directory is non-empty and must not
            be cleaned.
    """
    general = config["general"]
    save_root = general["save_root"]

    # Create save dir
    print("==> Saving to:", save_root)
    if not os.path.exists(save_root):
        os.makedirs(save_root)
    else:
        has_content = bool(os.listdir(save_root))
        whip_existing = general["whip_existing_files"]
        if has_content and not whip_existing:
            print(
                "The directory '{}' is not empty. If you wish to automatically delete all existing files "
                "in this directory before saving new data, please set 'whip_existing_files' to True in the "
                "configuration file. You can do this by editing the config file and changing the line under "
                "[general] to 'whip_existing_files: True'. This action will remove all files in the specified "
                "directory, so please ensure you have backups of any important files before proceeding.".format(
                    save_root
                )
            )
            import sys

            # Non-zero status: refusing to continue is a failure, not a success.
            sys.exit(1)
        elif whip_existing:
            # Clean up save dir
            print(
                "==> Whipping existing files in dir:",
                save_root,
                "whip_existing_files: True !",
            )
            remove_all_sub_files(save_root)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(general["save_model_dir"], exist_ok=True)


def model2device(full_package):
    """Move ``full_package["model"]`` to CUDA when available, else CPU.

    The chosen device is stored under ``full_package["device"]`` and the same
    (mutated) package is returned.
    """
    use_cuda = torch.cuda.is_available()
    target = torch.device("cuda" if use_cuda else "cpu")

    # TODO Multi GPU, using torch.distributed would be better than torch.nn.DataParallel
    moved_model = full_package["model"].to(target)
    full_package["model"] = moved_model
    full_package["device"] = target
    return full_package


def _build_arg_parser():
    """Create the command-line parser whose options can override the YAML config."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="/code/Clothing1MM/configs/Clothing1MPP/default.yaml",
        help="Path to the config file.",
    )
    parser.add_argument("--noise_rate", type=float, help="Set the noise rate.")
    parser.add_argument("--save_root", type=str, help="Set the save root.")
    parser.add_argument(
        "--save_every_epoch",
        type=int,
        help="Set to make the trainer save the checkpoint every epoch",
    )
    parser.add_argument(
        "--model_name", type=str, help="Set the model used for training"
    )
    parser.add_argument(
        "--learning_rate", type=float, help="Set the learning rate used for training"
    )
    parser.add_argument("--epoch", type=int, help="Set the number of training epochs")
    parser.add_argument(
        "--imbalance_factor",
        type=float,
        help="The imbalance of clothing1MPP dataset",
        default=1.0,
    )
    parser.add_argument("--wandb_key", type=str, help="The wandb key", default=None)
    parser.add_argument(
        "--wandb_entity", type=str, help="The wandb entity", default=None
    )
    parser.add_argument("--run_name", type=str, help="The run name", default=None)
    parser.add_argument("--tiny", action="store_true")
    parser.add_argument("--seed", type=int, default=None, help="seed in the algorithm")
    return parser


def _apply_cli_overrides(config, args):
    """Write the command-line values in ``args`` into ``config`` in place."""
    # Numeric options are compared against None so that an explicit 0 / 0.0
    # on the command line is honoured instead of being treated as "not given".
    if args.noise_rate is not None:
        config["data"]["noise_rate"] = args.noise_rate
    if args.save_root:
        config["general"]["save_root"] = args.save_root
    if args.save_every_epoch is not None:
        config["train"]["save_every_epoch"] = args.save_every_epoch
    if args.model_name:
        config["model"]["name"] = args.model_name
    if args.learning_rate is not None:
        config["train"]["learning_rate"] = args.learning_rate
    if args.epoch is not None:
        config["train"]["epoch"] = args.epoch
    if args.imbalance_factor < 1.0:
        assert (
            config["data"]["dataset_name"] == "clothing1mpp"
        ), "The other dataset has no attribute called imabalance"
        config["data"]["imbalance_factor"] = args.imbalance_factor
    if args.run_name:
        config["general"]["run_name"] = args.run_name

    if args.wandb_key:
        assert config["general"]["logger"]["name"] == "wandb"
        config["general"]["logger"]["wandb_key"] = args.wandb_key
    if args.wandb_entity:
        assert config["general"]["logger"]["name"] == "wandb"
        # Store the entity itself (previously the API key was written here).
        config["general"]["logger"]["wandb_entity"] = args.wandb_entity
    else:
        # NOTE(review): this clears any wandb_entity set in the config file
        # whenever --wandb_entity is not passed; kept as in the original.
        config["general"]["logger"]["wandb_entity"] = None
    if args.tiny and config["data"]["dataset_name"] == "clothing1mpp":
        config["data"]["tiny"] = True
    if args.seed is not None:
        assert args.seed >= 0, "The seed should be non-negative"
        # Only the numpy seed is overridden; torch_seed keeps its config value.
        config["general"]["np_seed"] = args.seed


def get_configs():
    """Parse the command line, build the run configuration and prepare its directory.

    Steps: load the YAML file given by ``--config`` (with inheritance), attach
    a fresh ``run_id``, apply command-line overrides, derive the run-specific
    save directory, print the result, create the directory and write the
    final configuration to ``config.yaml`` inside it.

    Returns:
        The final configuration dictionary.
    """
    args = _build_arg_parser().parse_args()

    timestamp = str(int(time.time()))  # [-6:]
    run_id = f"{timestamp}_{uuid.uuid4().hex[:6]}"

    print("==> Loading config:", args.config)
    config = load_config(args.config)
    config["general"]["run_id"] = run_id

    # Override configuration with command line argument
    _apply_cli_overrides(config, args)

    update_save_dir(config)

    print("==> Using config:")
    print_config(config)

    # Create save dir (mkdir)
    create_save_dir(config)

    # set CUDA_VISIBLE_DEVICES, Show CUDA devices,
    # mange_cuda_devices(config)

    # Save config
    config_file_path = os.path.join(config["general"]["save_root"], "config.yaml")
    print("==> Saving run config to:", config_file_path)
    with open(config_file_path, "w") as file:
        yaml.dump(config, file, default_flow_style=False)

    return config


def save_train_package(train_package, file_name, keys_to_save=None):
    """Write selected entries of ``train_package`` to ``file_name`` with ``torch.save``.

    When ``keys_to_save`` is ``None`` the entries ``config``, ``model``,
    ``optimizer`` and ``epoch`` are saved.
    """
    default_keys = ["config", "model", "optimizer", "epoch"]
    selected = default_keys if keys_to_save is None else keys_to_save

    snapshot = {key: train_package[key] for key in selected}
    torch.save(snapshot, file_name)


def load_train_package(file_name, keys_to_load=None, train_package=None):
    """Load selected entries of a saved train package into ``train_package``.

    Args:
        file_name: Path of a file written with ``torch.save``.
        keys_to_load: Keys to copy; defaults to ``config``, ``model``,
            ``optimizer`` and ``epoch``.
        train_package: Mapping to update in place. When omitted a new dict is
            created (previously the ``None`` default crashed on assignment).

    Returns:
        The updated (or newly created) train package.

    Raises:
        KeyError: If a requested key is missing from the saved file.
    """
    if keys_to_load is None:
        keys_to_load = ["config", "model", "optimizer", "epoch"]
    if train_package is None:
        train_package = {}

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    train_package_loaded = torch.load(file_name)
    for key in keys_to_load:
        train_package[key] = train_package_loaded[key]
    return train_package


def get_transform(dataset_name, image_size):
    """Build the train and test transform pipelines for a dataset.

    For ``cifar10`` / ``cifar100`` (case-insensitive exact match) the
    pipelines use that dataset's normalisation statistics and no resize; a
    warning is printed if ``image_size`` is not 32. Every other dataset name
    gets a resize to ``(image_size, image_size)`` first and a third set of
    normalisation statistics.

    The train pipeline adds a padded random crop and a random horizontal
    flip; the test pipeline only converts to tensor and normalises.

    Args:
        dataset_name: Name of the dataset.
        image_size: Side length used for the random crop (and the resize for
            non-CIFAR datasets).

    Returns:
        Tuple ``(train_transform, test_transform)``.
    """
    # (mean, std) passed to transforms.Normalize, keyed by lower-cased name.
    cifar_stats = {
        "cifar10": ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        "cifar100": ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    }
    default_stats = ((0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214))

    name = dataset_name.lower()
    is_cifar = name in cifar_stats
    if is_cifar:
        if image_size != 32:
            # Report the dataset actually requested (the cifar100 branch used
            # to print "cifar10").
            print(f"Warning: {name} image_size is not 32")
        mean, std = cifar_stats[name]
    else:
        mean, std = default_stats

    def resize_step():
        # CIFAR images are used at their stored size; everything else is
        # resized to the requested square size first.
        if is_cifar:
            return []
        return [transforms.Resize((image_size, image_size))]

    train_transform = transforms.Compose(
        resize_step()
        + [
            transforms.RandomCrop(image_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    test_transform = transforms.Compose(
        resize_step()
        + [
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )

    return train_transform, test_transform
